# Copyright 2020 ByteDance Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl import logging

from neurst.criterions import build_criterion, register_criterion
from neurst.criterions.criterion import Criterion
from neurst.utils.flags_core import Flag


@register_criterion
class JointCriterion(Criterion):
    """ A criterion composed of several sub-criterions.

    Each sub-criterion is built from its own settings; its reduced training
    loss is scaled by a per-criterion weight `alpha` (1.0 by default).
    Losses and evaluation metrics are reported per sub-criterion, keyed by
    the sub-criterion's class name (suffixed with "_<k>" when several
    sub-criterions share the same class, so none of them is overwritten).
    """

    def __init__(self, args):
        """ Initializes the joint criterion from a list of sub-criterion settings.

        Args:
            args: A dict of full parameters. `args["criterions"]` must be a
                non-empty list whose elements are criterion settings accepted
                by `build_criterion`, each optionally carrying an extra
                "alpha" entry: the weight of that criterion's loss.
        """
        super(JointCriterion, self).__init__()
        assert isinstance(args["criterions"], list) and len(args["criterions"]) > 0, (
            "`criterions` should be a list of multiple criterion settings for JointCriterion.")
        self._criterions = []
        self._alphas = []
        logging.info("Creating JointCriterion for training, which is composed by:")
        for crit in args["criterions"]:
            # Pop "alpha" from a shallow copy so the caller's configuration is
            # left intact (it may be reused, e.g. to rebuild this criterion).
            crit = dict(crit)
            self._alphas.append(crit.pop("alpha", 1.0))
            self._criterions.append(build_criterion(crit))
            logging.info("  - {} with alpha={}".format(self._criterions[-1].__class__.__name__, self._alphas[-1]))

    @staticmethod
    def class_or_method_args():
        """ Returns a list of args for flag definition. """
        return [Flag("criterions", dtype=Flag.TYPE.STRING,
                     default=None, help="A list of multiple criterions. Each element should be a dict like "
                                        "\"{'criterion.class': '...', 'criterion.params': '...', 'alpha': 1.0}\", "
                                        "where alpha denotes the weight of the corresponding criterion "
                                        "and will be set to 1.0 by default.")]

    def _criterion_names(self):
        """ Returns one unique name per sub-criterion, in order.

        The name is the sub-criterion's class name. If that name is already
        taken by an earlier sub-criterion, a "_<k>" suffix (k = 1, 2, ...) is
        appended until it is unique, so results of sub-criterions sharing a
        class do not overwrite each other. A sub-criterion whose class name is
        unique keeps its bare class name.
        """
        names = []
        used = set()
        for criterion in self._criterions:
            base = criterion.__class__.__name__
            name, suffix = base, 0
            while name in used:
                suffix += 1
                name = "{}_{}".format(base, suffix)
            used.add(name)
            names.append(name)
        return names

    def reduce_loss(self, model_inp, model_out):
        """ Reduces loss tensor for training according to the model inputs
            and outputs.

        Each sub-criterion's reduced loss is multiplied by its alpha. Only
        sub-criterions whose `reduce_loss` returns a single loss tensor are
        supported.

        Returns: A dict mapping each sub-criterion's name (see
            `_criterion_names`) to its weighted loss.
        """
        reduced_loss = dict()
        for name, alpha, criterion in zip(self._criterion_names(), self._alphas, self._criterions):
            # We now only consider to joint the criterion with single loss tensor.
            reduced_loss[name] = alpha * criterion.reduce_loss(model_inp, model_out)
        return reduced_loss

    def reduce_metrics(self, eval_res_list):
        """ Reduces the metrics according to a list of returned value from `eval`.

        Args:
            eval_res_list: A list of evaluation results, each element being the
                output of `self.__call__` (one entry per sub-criterion, in
                the same order as the sub-criterions).

        Returns:
            A dict of reduced metrics for evaluation, keyed by
            "<sub-criterion name>/<metric name>".
        """
        final_metric = dict()
        # Transpose per-batch lists of per-criterion results into
        # per-criterion lists of per-batch results.
        per_criterion_results = [list(res) for res in zip(*eval_res_list)]
        for name, criterion, eval_res_l in zip(self._criterion_names(), self._criterions,
                                               per_criterion_results):
            for k, v in criterion.reduce_metrics(eval_res_l).items():
                final_metric[name + "/" + k] = v
        return final_metric

    def as_metric(self):
        """ Returns a wrapper class of Metric.

        Only the first sub-criterion's metric is used.
        """
        return self._criterions[0].as_metric()

    def __call__(self, model_inp, model_out):
        """ Returns a list of __call__ output corresponding to each criterion
        (in the same order as the sub-criterions). """
        return [x(model_inp, model_out) for x in self._criterions]
